{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import cv2\n",
    "import random\n",
    "import time\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "数据集是mnist，28\\*28，这里选择提取HOG特征，方向梯度直方图（Histogram of Oriented Gradient, HOG）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "raw_data = pd.read_csv('../data/train.csv',header=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>pixel0</th>\n",
       "      <th>pixel1</th>\n",
       "      <th>pixel2</th>\n",
       "      <th>pixel3</th>\n",
       "      <th>pixel4</th>\n",
       "      <th>pixel5</th>\n",
       "      <th>pixel6</th>\n",
       "      <th>pixel7</th>\n",
       "      <th>pixel8</th>\n",
       "      <th>...</th>\n",
       "      <th>pixel774</th>\n",
       "      <th>pixel775</th>\n",
       "      <th>pixel776</th>\n",
       "      <th>pixel777</th>\n",
       "      <th>pixel778</th>\n",
       "      <th>pixel779</th>\n",
       "      <th>pixel780</th>\n",
       "      <th>pixel781</th>\n",
       "      <th>pixel782</th>\n",
       "      <th>pixel783</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 785 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   label  pixel0  pixel1  pixel2  pixel3  pixel4  pixel5  pixel6  pixel7  \\\n",
       "0      1       0       0       0       0       0       0       0       0   \n",
       "1      0       0       0       0       0       0       0       0       0   \n",
       "2      1       0       0       0       0       0       0       0       0   \n",
       "3      4       0       0       0       0       0       0       0       0   \n",
       "4      0       0       0       0       0       0       0       0       0   \n",
       "\n",
       "   pixel8    ...     pixel774  pixel775  pixel776  pixel777  pixel778  \\\n",
       "0       0    ...            0         0         0         0         0   \n",
       "1       0    ...            0         0         0         0         0   \n",
       "2       0    ...            0         0         0         0         0   \n",
       "3       0    ...            0         0         0         0         0   \n",
       "4       0    ...            0         0         0         0         0   \n",
       "\n",
       "   pixel779  pixel780  pixel781  pixel782  pixel783  \n",
       "0         0         0         0         0         0  \n",
       "1         0         0         0         0         0  \n",
       "2         0         0         0         0         0  \n",
       "3         0         0         0         0         0  \n",
       "4         0         0         0         0         0  \n",
       "\n",
       "[5 rows x 785 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(42000, 785)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_data.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "两个冒号的语法：\n",
    "    seq[start:end:step]\n",
    "原来是\n",
    "    imgs = data[0::,1::]\n",
    "    labels = data[::,0]\n",
    "没必要这样写"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data = raw_data.values\n",
    "imgs = data[:, 1:]\n",
    "labels = data[:, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(42000, 784)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imgs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 利用opencv获取图像hog特征\n",
    "def get_hog_features(trainset):\n",
    "    features = []\n",
    "\n",
    "    hog = cv2.HOGDescriptor('../hog.xml')\n",
    "\n",
    "    for img in trainset:\n",
    "        img = np.reshape(img,(28,28))\n",
    "        cv_img = img.astype(np.uint8)\n",
    "\n",
    "        hog_feature = hog.compute(cv_img)\n",
    "        # hog_feature = np.transpose(hog_feature)\n",
    "        features.append(hog_feature)\n",
    "\n",
    "    features = np.array(features)\n",
    "    features = np.reshape(features,(-1,324))\n",
    "\n",
    "    return features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "features = get_hog_features(imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(42000, 324)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "features.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(42000,)"
      ]
     },
     "execution_count": 112,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.33, random_state=23323)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 预测\n",
    "\n",
    "因为knn不需要训练，我们可以直接进行预测。不过因为4万个数据即使是预测也非常花时间，这里只取前100个样本做训练集，取前30个样本做测试集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "testset, trainset, train_labels = test_features[:30], train_features[:100], train_labels[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "k = 10 # 最近的10个点\n",
    "\n",
    "predict = []\n",
    "count = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.0"
      ]
     },
     "execution_count": 122,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 计算两个点的欧氏距离\n",
    "np.linalg.norm(np.array([0, 3]) - np.array([4, 0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "time_1 = time.time()\n",
    "\n",
    "for test_vec in testset:\n",
    "    # 输出当前运行的测试用例坐标，用于测试\n",
    "    count += 1\n",
    "    if count % 5000 == 0:\n",
    "        print(count)\n",
    "        \n",
    "    knn_list = np.zeros((1, 2))    # 初始化，存放当前k个最近邻居\n",
    "    \n",
    "    # 先将前k个点放入k个最近邻居中，填充满knn_list\n",
    "    for i in range(k):\n",
    "        label = train_labels[i]\n",
    "        train_vec = trainset[i]\n",
    "\n",
    "        dist = np.linalg.norm(train_vec - test_vec)         # 计算两个点的欧氏距离\n",
    "        knn_list = np.append(knn_list, [[dist, label]], axis=0)\n",
    "        \n",
    "    # 剩下的点\n",
    "    for i in range(k, len(train_labels)):\n",
    "        label = train_labels[i]\n",
    "        train_vec = trainset[i]\n",
    "\n",
    "        dist = np.linalg.norm(train_vec - test_vec)         # 计算两个点的欧氏距离\n",
    "\n",
    "        # 寻找10个邻近点中距离最远的点\n",
    "        max_index = np.argmax(knn_list[:, 0])\n",
    "        max_dist = np.max(knn_list[:, 0])\n",
    "\n",
    "        # 如果当前k个最近邻居中存在点距离比当前点距离远，则替换\n",
    "        if dist < max_dist:\n",
    "            knn_list[max_index] = [dist, label]\n",
    "            \n",
    "            \n",
    "    # 上面代码计算全部运算完之后，即说明已经找到了离当前test_vec最近的10个train_vec\n",
    "    # 统计选票\n",
    "    class_total = 10\n",
    "    class_count = [0 for i in range(class_total)]\n",
    "    for dist, label in knn_list:\n",
    "        class_count[int(label)] += 1\n",
    "\n",
    "    # 找出最大选票数\n",
    "    label_max = max(class_count)\n",
    "\n",
    "    # 最大选票数对应的class\n",
    "    predict.append(class_count.index(label_max))\n",
    "\n",
    "time_2 = time.time()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train time is 0.07612895965576172\n"
     ]
    }
   ],
   "source": [
    "print('train time is %s' % (time_2 - time_1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train time is 3\n"
     ]
    }
   ],
   "source": [
    "print('train time is %s' % (5-2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.        ,  0.        ],\n",
       "       [ 1.10036302,  3.        ],\n",
       "       [ 1.09803486,  3.        ],\n",
       "       [ 1.09235775,  3.        ],\n",
       "       [ 1.03992426,  3.        ],\n",
       "       [ 1.04467952,  3.        ],\n",
       "       [ 1.06501627,  3.        ],\n",
       "       [ 0.93764162,  3.        ],\n",
       "       [ 1.05351973,  3.        ],\n",
       "       [ 1.04691565,  3.        ],\n",
       "       [ 0.9816038 ,  3.        ]])"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([], dtype=float64)"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_list = np.array([])     # 当前k个最近邻居\n",
    "    \n",
    "# 先将前k个点放入k个最近邻居中，填充满knn_list\n",
    "for i in range(k):\n",
    "    label = train_labels[i]\n",
    "    train_vec = trainset[i]\n",
    "\n",
    "    dist = np.linalg.norm(train_vec - test_vec)         # 计算两个点的欧氏距离\n",
    "    knn_list_test = np.append(knn_list_test, [[8.5, 9]], axis=0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试用\n",
    "\n",
    "下面自己写一个寻找10个邻近点中距离最远的点："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.,  0.]])"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_list = np.zeros((1, 2))  # 当前k个最近邻居\n",
    "knn_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "all the input array dimensions except for the concatenation axis must match exactly",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-94-9db999664314>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mknn_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m8.5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m9\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/Users/xu/anaconda/envs/py35/lib/python3.5/site-packages/numpy/lib/function_base.py\u001b[0m in \u001b[0;36mappend\u001b[0;34m(arr, values, axis)\u001b[0m\n\u001b[1;32m   5145\u001b[0m         \u001b[0mvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   5146\u001b[0m         \u001b[0maxis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 5147\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m: all the input array dimensions except for the concatenation axis must match exactly"
     ]
    }
   ],
   "source": [
    "np.append(knn_list, [[8.5, 9]], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 2.3,  1. ],\n",
       "       [ 3.5,  1. ],\n",
       "       [ 1.5,  4. ],\n",
       "       [ 6.5,  2. ],\n",
       "       [ 5.5,  8. ]])"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_list_test = np.array([[2.3, 1], [3.5, 1], [1.5, 4], [6.5, 2], [5.5, 8]])\n",
    "# 每个元组里，第一个是距离，第二个是对应标签\n",
    "knn_list_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 2.3,  3.5,  1.5,  6.5,  5.5])"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_list_test[:, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "knn_list_test[2] = [9.5, 5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 2.3,  1. ],\n",
       "       [ 3.5,  1. ],\n",
       "       [ 9.5,  5. ],\n",
       "       [ 6.5,  2. ],\n",
       "       [ 5.5,  8. ]])"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_list_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "要想给一个ndarray添加一个元素，必须是同样的格式，即必须是`[[8.5, 9]]`，不能是`[8.5, 9]`，而且必须要用axis指定才行。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 2.3,  1. ],\n",
       "       [ 3.5,  1. ],\n",
       "       [ 9.5,  5. ],\n",
       "       [ 6.5,  2. ],\n",
       "       [ 5.5,  8. ],\n",
       "       [ 8.5,  9. ],\n",
       "       [ 8.5,  9. ]])"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.append(knn_list_test, [[8.5, 9]], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 2.3,  1. ],\n",
       "       [ 3.5,  1. ],\n",
       "       [ 9.5,  5. ],\n",
       "       [ 6.5,  2. ],\n",
       "       [ 5.5,  8. ],\n",
       "       [ 8.5,  9. ]])"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_list_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_list_test[:, 0].argmax()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([], dtype=float64)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 输出评分\n",
    "\n",
    "统计结束后，得到predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30"
      ]
     },
     "execution_count": 125,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_predict = np.array(predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "score = accuracy_score(test_labels[:30], test_predict)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6333333333333333"
      ]
     },
     "execution_count": 129,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [py35]",
   "language": "python",
   "name": "Python [py35]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
